{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K2s1A9eLRPEj"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "VRLVEKiTEn04"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cffg2i257iMS"
      },
      "source": [
        "# ビジュアルアテンションを用いた画像キャプショニング\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/text/image_captioning\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/text/image_captioning.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/text/image_captioning.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/text/image_captioning.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RoRi17kuCYFs"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QASbY_HGo4Lq"
      },
      "source": [
        "下記のような画像をもとに、\"a surfer riding on a wave\" のようなキャプションを生成することをめざします。\n",
        "\n",
        "![Man Surfing](https://tensorflow.org/images/surf.jpg)\n",
        "\n",
        "_[画像ソース](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg); ライセンス: パブリックドメイン_\n",
        "\n",
        "これを達成するため、アテンションベースのモデルを用います。これにより、キャプションを生成する際にモデルが画像のどの部分に焦点を当てているかを見ることができます。\n",
        "\n",
        "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
        "\n",
        "モデルの構造は、[Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044) と類似のものです。\n",
        "\n",
        "このノートブックには例として、一連の処理の最初から最後までが含まれています。このノートブックを実行すると、[MS-COCO](http://cocodataset.org/#home) データセットをダウンロードし、Inception V3 を使って画像のサブセットを前処理し、キャッシュします。その後、エンコーダー・デコーダーモデルを訓練し、訓練したモデルを使って新しい画像のキャプションを生成します。\n",
        "\n",
        "この例では、比較的少量のデータ、およそ 20,000 枚の画像に対する最初の 30,000 のキャプションを使ってモデルを訓練します（データセットには 1 枚の画像あたり複数のキャプションがあるからです）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U8l4RJ0XRPEm"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# モデルがキャプション生成中に画像のどの部分に注目しているかを見るために\n",
        "# アテンションのプロットを生成\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Scikit-learn には役に立つさまざまなユーティリティが含まれる\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.utils import shuffle\n",
        "\n",
        "import re\n",
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "import json\n",
        "from glob import glob\n",
        "from PIL import Image\n",
        "import pickle"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6qbGw8MRPE5"
      },
      "source": [
        "## MS-COCO データセットのダウンロードと準備\n",
        "\n",
        "モデルの訓練には [MS-COCO データセット](http://cocodataset.org/#home) を使用します。このデータセットには、82,000 枚以上の画像が含まれ、それぞれの画像には少なくとも 5 つの異なったキャプションがつけられています。下記のコードは自動的にデータセットをダウンロードして解凍します。\n",
        "\n",
        "**Caution: 巨大ファイルのダウンロードあり** 訓練用のデータセットは、13GBのファイルです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krQuPYTtRPE7"
      },
      "outputs": [],
      "source": [
        "annotation_zip = tf.keras.utils.get_file('captions.zip',\n",
        "                                          cache_subdir=os.path.abspath('.'),\n",
        "                                          origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n",
        "                                          extract = True)\n",
        "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n",
        "\n",
        "name_of_zip = 'train2014.zip'\n",
        "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n",
        "  image_zip = tf.keras.utils.get_file(name_of_zip,\n",
        "                                      cache_subdir=os.path.abspath('.'),\n",
        "                                      origin = 'http://images.cocodataset.org/zips/train2014.zip',\n",
        "                                      extract = True)\n",
        "  PATH = os.path.dirname(image_zip)+'/train2014/'\n",
        "else:\n",
        "  PATH = os.path.abspath('.')+'/train2014/'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aANEzb5WwSzg"
      },
      "source": [
        "## オプション： 訓練用データセットのサイズ制限\n",
        "\n",
        "このチュートリアルの訓練をスピードアップするため、サブセットである 30,000 のキャプションと対応する画像を使ってモデルを訓練します。より多くのデータを使えばキャプションの品質が向上するでしょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4G3b8x8_RPFD"
      },
      "outputs": [],
      "source": [
        "# json ファイルの読み込み\n",
        "with open(annotation_file, 'r') as f:\n",
        "    annotations = json.load(f)\n",
        "\n",
        "# ベクトルにキャプションと画像の名前を格納\n",
        "all_captions = []\n",
        "all_img_name_vector = []\n",
        "\n",
        "for annot in annotations['annotations']:\n",
        "    caption = '<start> ' + annot['caption'] + ' <end>'\n",
        "    image_id = annot['image_id']\n",
        "    full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n",
        "\n",
        "    all_img_name_vector.append(full_coco_image_path)\n",
        "    all_captions.append(caption)\n",
        "\n",
        "# captions と image_names を一緒にシャッフル\n",
        "# random_state を設定\n",
        "train_captions, img_name_vector = shuffle(all_captions,\n",
        "                                          all_img_name_vector,\n",
        "                                          random_state=1)\n",
        "\n",
        "# シャッフルしたデータセットから最初の 30,000 のキャプションを選択\n",
        "num_examples = 30000\n",
        "train_captions = train_captions[:num_examples]\n",
        "img_name_vector = img_name_vector[:num_examples]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mPBMgK34RPFL"
      },
      "outputs": [],
      "source": [
        "len(train_captions), len(all_captions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8cSW4u-ORPFQ"
      },
      "source": [
        "## InceptionV3 を使った画像の前処理\n",
        "\n",
        "つぎに、（Imagenet を使って訓練済みの）InceptionV3をつかって、画像を分類します。最後の畳み込み層から特徴量を抽出します。\n",
        "\n",
        "最初に、つぎのようにして画像を InceptionV3 が期待するフォーマットに変換します。\n",
        "* 画像を 299 ピクセル × 299 ピクセルにリサイズ\n",
        "* [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) メソッドをつかって画像を InceptionV3 の訓練用画像のフォーマットに合致した −1 から 1 の範囲のピクセルを持つ形式に標準化して[画像の前処理](https://cloud.google.com/tpu/docs/inception-v3-advanced#preprocessing_stage) を行う"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zXR0217aRPFR"
      },
      "outputs": [],
      "source": [
        "def load_image(image_path):\n",
        "    img = tf.io.read_file(image_path)\n",
        "    img = tf.image.decode_jpeg(img, channels=3)\n",
        "    img = tf.image.resize(img, (299, 299))\n",
        "    img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
        "    return img, image_path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MDvIu4sXRPFV"
      },
      "source": [
        "## InceptionV3 を初期化し Imagenet で学習済みの重みをロードする\n",
        "\n",
        "ここで、出力層が InceptionV3 アーキテクチャの最後の畳み込み層である tf.keras モデルを作成しましょう。このレイヤーの出力の shape は```8x8x2048``` です。\n",
        "\n",
        "* 画像を 1 枚ずつネットワークに送り込み、処理結果のベクトルをディクショナリに保管します (image_name --> feature_vector)\n",
        "* すべての画像をネットワークで処理したあと、ディクショナリを pickle 形式でディスクに書き出します\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RD3vW4SsRPFW"
      },
      "outputs": [],
      "source": [
        "image_model = tf.keras.applications.InceptionV3(include_top=False,\n",
        "                                                weights='imagenet')\n",
        "new_input = image_model.input\n",
        "hidden_layer = image_model.layers[-1].output\n",
        "\n",
        "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rERqlR3WRPGO"
      },
      "source": [
        "## InceptionV3 から抽出した特徴量のキャッシング\n",
        "\n",
        "それぞれの画像を InceptionV3 で前処理して出力をディスクにキャッシュします。出力を RAM にキャッシュすれば速くなりますが、画像 1 枚あたり 8 \\* 8 \\* 2048 個の浮動小数点数が必要で、メモリを大量に必要とします。これを書いている時点では、これは Colab のメモリ上限（現在は 12GB）を超えています。\n",
        "\n",
        "より高度なキャッシング戦略（たとえば画像を分散保存しディスクのランダムアクセス入出力を低減するなど）を使えば性能は向上しますが、より多くのコーディングが必要となります。\n",
        "\n",
        "このキャッシングは Colab で GPU を使った場合で約 10 分ほどかかります。プログレスバーを表示したければ次のようにします。\n",
        "\n",
        "1. [tqdm](https://github.com/tqdm/tqdm) をインストールします:\n",
        "\n",
        "    `!pip install tqdm`\n",
        "\n",
        "2. tqdm をインポートします:\n",
        "\n",
        "    `from tqdm import tqdm`\n",
        "\n",
        "3. 下記の行を:\n",
        "\n",
        "    `for img, path in image_dataset:`\n",
        "\n",
        "    次のように変更します:\n",
        "\n",
        "    `for img, path in tqdm(image_dataset):`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dx_fvbVgRPGQ"
      },
      "outputs": [],
      "source": [
        "# 重複のない画像を取得\n",
        "encode_train = sorted(set(img_name_vector))\n",
        "\n",
        "# batch_size はシステム構成に合わせて自由に変更可能\n",
        "image_dataset = tf.data.Dataset.from_tensor_slices(encode_train)\n",
        "image_dataset = image_dataset.map(\n",
        "  load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(16)\n",
        "\n",
        "for img, path in image_dataset:\n",
        "  batch_features = image_features_extract_model(img)\n",
        "  batch_features = tf.reshape(batch_features,\n",
        "                              (batch_features.shape[0], -1, batch_features.shape[3]))\n",
        "\n",
        "  for bf, p in zip(batch_features, path):\n",
        "    path_of_feature = p.numpy().decode(\"utf-8\")\n",
        "    np.save(path_of_feature, bf.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nyqH3zFwRPFi"
      },
      "source": [
        "## キャプションの前処理とトークン化\n",
        "\n",
        "* まず最初に、キャプションをトークン化（たとえばスペースで区切るなど）します。これにより、データ中の重複しない単語のボキャブラリ（たとえば、\"surfing\"、\"football\" など）が得られます。\n",
        "* つぎに、（メモリ節約のため）ボキャブラリのサイズを上位 5,000 語に制限します。それ以外の単語は \"UNK\" （不明）というトークンに置き換えます。\n",
        "* 続いて、単語からインデックス、インデックスから単語への対応表を作成します。\n",
        "* 最後に、すべてのシーケンスの長さを最長のものに合わせてパディングします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HZfK8RhQRPFj"
      },
      "outputs": [],
      "source": [
        "# データセット中の一番長いキャプションの長さを検出\n",
        "def calc_max_length(tensor):\n",
        "    return max(len(t) for t in tensor)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oJGE34aiRPFo"
      },
      "outputs": [],
      "source": [
        "# ボキャブラリ中のトップ 5000 語を選択\n",
        "top_k = 5000\n",
        "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k,\n",
        "                                                  oov_token=\"<unk>\",\n",
        "                                                  filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n",
        "tokenizer.fit_on_texts(train_captions)\n",
        "train_seqs = tokenizer.texts_to_sequences(train_captions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Q44tNQVRPFt"
      },
      "outputs": [],
      "source": [
        "tokenizer.word_index['<pad>'] = 0\n",
        "tokenizer.index_word[0] = '<pad>'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0fpJb5ojRPFv"
      },
      "outputs": [],
      "source": [
        "# トークン化したベクトルを生成\n",
        "train_seqs = tokenizer.texts_to_sequences(train_captions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AidglIZVRPF4"
      },
      "outputs": [],
      "source": [
        "# キャプションの最大長に各ベクトルをパディング\n",
        "# max_length を指定しない場合、pad_sequences は自動的に計算\n",
        "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gL0wkttkRPGA"
      },
      "outputs": [],
      "source": [
        "# アテンションの重みを格納するために使われる max_length を計算\n",
        "max_length = calc_max_length(train_seqs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M3CD75nDpvTI"
      },
      "source": [
        "## データを訓練用とテスト用に分割"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iS7DDMszRPGF"
      },
      "outputs": [],
      "source": [
        "# 訓練用セットと検証用セットを 80-20 に分割して生成\n",
        "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector,\n",
        "                                                                    cap_vector,\n",
        "                                                                    test_size=0.2,\n",
        "                                                                    random_state=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XmViPkRFRPGH"
      },
      "outputs": [],
      "source": [
        "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uEWM9xrYcg45"
      },
      "source": [
        "## 訓練用の tf.data データセットの作成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "horagNvhhZiy"
      },
      "source": [
        "画像とキャプションが用意できました。つぎは、モデルの訓練に使用する tf.data データセットを作成しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q3TnZ1ToRPGV"
      },
      "outputs": [],
      "source": [
        "# これらのパラメータはシステム構成に合わせて自由に変更してください\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = 1000\n",
        "embedding_dim = 256\n",
        "units = 512\n",
        "vocab_size = len(tokenizer.word_index) + 1\n",
        "num_steps = len(img_name_train) // BATCH_SIZE\n",
        "# InceptionV3 から抽出したベクトルの shape は (64, 2048)\n",
        "# つぎの 2 つのパラメータはこのベクトルの shape を表す\n",
        "features_shape = 2048\n",
        "attention_features_shape = 64"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SmZS2N0bXG3T"
      },
      "outputs": [],
      "source": [
        "# numpy ファイルをロード\n",
        "def map_func(img_name, cap):\n",
        "  img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n",
        "  return img_tensor, cap"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FDF_Nm3tRPGZ"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n",
        "\n",
        "# numpy ファイルを並列に読み込むために map を使用\n",
        "dataset = dataset.map(lambda item1, item2: tf.numpy_function(\n",
        "          map_func, [item1, item2], [tf.float32, tf.int32]),\n",
        "          num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "# シャッフルとバッチ化\n",
        "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrvoDphgRPGd"
      },
      "source": [
        "## モデル\n",
        "\n",
        "興味深い事実：下記のデコーダは[Neural Machine Translation with Attention](../sequences/nmt_with_attention.ipynb) の例のデコーダとまったく同一です。\n",
        "\n",
        "このモデルのアーキテクチャは、[Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) の論文にインスパイアされたものです。\n",
        "\n",
        "* この例では、InceptionV3 の下層の畳込みレイヤーから特徴量を取り出します。得られるベクトルの shape は (8, 8, 2048) です。\n",
        "* このベクトルを (64, 2048) に変形します。\n",
        "* このベクトルは（1層の全結合層からなる）CNN エンコーダに渡されます。\n",
        "* RNN（ここではGRU）が画像を介して次の単語を予測します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ja2LFTMSdeV3"
      },
      "outputs": [],
      "source": [
        "class BahdanauAttention(tf.keras.Model):\n",
        "  def __init__(self, units):\n",
        "    super(BahdanauAttention, self).__init__()\n",
        "    self.W1 = tf.keras.layers.Dense(units)\n",
        "    self.W2 = tf.keras.layers.Dense(units)\n",
        "    self.V = tf.keras.layers.Dense(1)\n",
        "\n",
        "  def call(self, features, hidden):\n",
        "    # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n",
        "\n",
        "    # hidden shape == (batch_size, hidden_size)\n",
        "    # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n",
        "    hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
        "\n",
        "    # score shape == (batch_size, 64, hidden_size)\n",
        "    score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n",
        "\n",
        "    # attention_weights shape == (batch_size, 64, 1)\n",
        "    # score を self.V に適用するので、最後の軸は 1 となる\n",
        "    attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
        "\n",
        "    # 合計をとったあとの　context_vector の shpae == (batch_size, hidden_size)\n",
        "    context_vector = attention_weights * features\n",
        "    context_vector = tf.reduce_sum(context_vector, axis=1)\n",
        "\n",
        "    return context_vector, attention_weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AZ7R1RxHRPGf"
      },
      "outputs": [],
      "source": [
        "class CNN_Encoder(tf.keras.Model):\n",
        "    # すでに特徴量を抽出して pickle 形式でダンプしてあるので\n",
        "    # このエンコーダはそれらの特徴量を全結合層に渡して処理する\n",
        "    def __init__(self, embedding_dim):\n",
        "        super(CNN_Encoder, self).__init__()\n",
        "        # shape after fc == (batch_size, 64, embedding_dim)\n",
        "        self.fc = tf.keras.layers.Dense(embedding_dim)\n",
        "\n",
        "    def call(self, x):\n",
        "        x = self.fc(x)\n",
        "        x = tf.nn.relu(x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V9UbGQmERPGi"
      },
      "outputs": [],
      "source": [
        "class RNN_Decoder(tf.keras.Model):\n",
        "  def __init__(self, embedding_dim, units, vocab_size):\n",
        "    super(RNN_Decoder, self).__init__()\n",
        "    self.units = units\n",
        "\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = tf.keras.layers.GRU(self.units,\n",
        "                                   return_sequences=True,\n",
        "                                   return_state=True,\n",
        "                                   recurrent_initializer='glorot_uniform')\n",
        "    self.fc1 = tf.keras.layers.Dense(self.units)\n",
        "    self.fc2 = tf.keras.layers.Dense(vocab_size)\n",
        "\n",
        "    self.attention = BahdanauAttention(self.units)\n",
        "\n",
        "  def call(self, x, features, hidden):\n",
        "    # アテンションを別のモデルとして定義\n",
        "    context_vector, attention_weights = self.attention(features, hidden)\n",
        "\n",
        "    # embedding 層を通過したあとの x の shape == (batch_size, 1, embedding_dim)\n",
        "    x = self.embedding(x)\n",
        "\n",
        "    # 結合後の x の shape == (batch_size, 1, embedding_dim + hidden_size)\n",
        "    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
        "\n",
        "    # 結合したベクトルを GRU に渡す\n",
        "    output, state = self.gru(x)\n",
        "\n",
        "    # shape == (batch_size, max_length, hidden_size)\n",
        "    x = self.fc1(output)\n",
        "\n",
        "    # x shape == (batch_size * max_length, hidden_size)\n",
        "    x = tf.reshape(x, (-1, x.shape[2]))\n",
        "\n",
        "    # output shape == (batch_size * max_length, vocab)\n",
        "    x = self.fc2(x)\n",
        "\n",
        "    return x, state, attention_weights\n",
        "\n",
        "  def reset_state(self, batch_size):\n",
        "    return tf.zeros((batch_size, self.units))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qs_Sr03wRPGk"
      },
      "outputs": [],
      "source": [
        "encoder = CNN_Encoder(embedding_dim)\n",
        "decoder = RNN_Decoder(embedding_dim, units, vocab_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-bYN7xA0RPGl"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.keras.optimizers.Adam()\n",
        "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "    from_logits=True, reduction='none')\n",
        "\n",
        "def loss_function(real, pred):\n",
        "  mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
        "  loss_ = loss_object(real, pred)\n",
        "\n",
        "  mask = tf.cast(mask, dtype=loss_.dtype)\n",
        "  loss_ *= mask\n",
        "\n",
        "  return tf.reduce_mean(loss_)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6A3Ni64joyab"
      },
      "source": [
        "## チェックポイント"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PpJAqPMWo0uE"
      },
      "outputs": [],
      "source": [
        "checkpoint_path = \"./checkpoints/train\"\n",
        "ckpt = tf.train.Checkpoint(encoder=encoder,\n",
        "                           decoder=decoder,\n",
        "                           optimizer = optimizer)\n",
        "ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fUkbqhc_uObw"
      },
      "outputs": [],
      "source": [
        "start_epoch = 0\n",
        "if ckpt_manager.latest_checkpoint:\n",
        "  start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PHod7t72RPGn"
      },
      "source": [
        "## 訓練\n",
        "\n",
        "* それぞれの `.npy` ファイルに保存されている特徴量を取り出し、エンコーダに渡す\n",
        "* エンコーダの出力と（0 で初期化された）隠れ状態と、デコーダの入力（開始トークン）をデコーダに渡す\n",
        "* デコーダは予測値とデコーダの隠れ状態を返す\n",
        "* デコーダの隠れ状態はモデルに戻され、予測値をつかって損失を計算する\n",
        "* デコーダの次の入力を決めるために teacher forcing を用いる\n",
        "* Teacher forcing は、デコーダの次の入力として正解の単語を渡す手法である\n",
        "* 最後のステップは、勾配を計算してそれをオプティマイザに適用し誤差逆伝播を行うことである"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vt4WZ5mhJE-E"
      },
      "outputs": [],
      "source": [
        "# 訓練用のセルを複数回実行すると loss_plot 配列がリセットされてしまうので\n",
        "# 独立したセルとして追加\n",
        "loss_plot = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sqgyz2ANKlpU"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(img_tensor, target):\n",
        "  loss = 0\n",
        "\n",
        "  # バッチごとに隠れ状態を初期化\n",
        "  # 画像のキャプションはその前後の画像と無関係なため\n",
        "  hidden = decoder.reset_state(batch_size=target.shape[0])\n",
        "\n",
        "  dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * target.shape[0], 1)\n",
        "\n",
        "  with tf.GradientTape() as tape:\n",
        "      features = encoder(img_tensor)\n",
        "\n",
        "      for i in range(1, target.shape[1]):\n",
        "          # 特徴量をデコーダに渡す\n",
        "          predictions, hidden, _ = decoder(dec_input, features, hidden)\n",
        "\n",
        "          loss += loss_function(target[:, i], predictions)\n",
        "\n",
        "          # teacher forcing を使用\n",
        "          dec_input = tf.expand_dims(target[:, i], 1)\n",
        "\n",
        "  total_loss = (loss / int(target.shape[1]))\n",
        "\n",
        "  trainable_variables = encoder.trainable_variables + decoder.trainable_variables\n",
        "\n",
        "  gradients = tape.gradient(loss, trainable_variables)\n",
        "\n",
        "  optimizer.apply_gradients(zip(gradients, trainable_variables))\n",
        "\n",
        "  return loss, total_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UlA4VIQpRPGo"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 20\n",
        "\n",
        "for epoch in range(start_epoch, EPOCHS):\n",
        "    start = time.time()\n",
        "    total_loss = 0\n",
        "\n",
        "    for (batch, (img_tensor, target)) in enumerate(dataset):\n",
        "        batch_loss, t_loss = train_step(img_tensor, target)\n",
        "        total_loss += t_loss\n",
        "\n",
        "        if batch % 100 == 0:\n",
        "            print ('Epoch {} Batch {} Loss {:.4f}'.format(\n",
        "              epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))\n",
        "    # 後ほどグラフ化するためにエポックごとに損失を保存\n",
        "    loss_plot.append(total_loss / num_steps)\n",
        "\n",
        "    if epoch % 5 == 0:\n",
        "      ckpt_manager.save()\n",
        "\n",
        "    print ('Epoch {} Loss {:.6f}'.format(epoch + 1,\n",
        "                                         total_loss/num_steps))\n",
        "    print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Wm83G-ZBPcC"
      },
      "outputs": [],
      "source": [
        "plt.plot(loss_plot)\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.title('Loss Plot')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xGvOcLQKghXN"
      },
      "source": [
        "## キャプション！\n",
        "\n",
        "* 評価関数は訓練ループとおなじだが、teacher forcing は使わない。タイムステップごとのデコーダへの入力は、隠れ状態とエンコーダの入力に加えて、一つ前の予測値である。\n",
        "* モデルが終了トークンを予測したら、予測を終了する。\n",
        "* それぞれのタイムステップごとに、アテンションの重みを保存する。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RCWpDtyNRPGs"
      },
      "outputs": [],
      "source": [
        "def evaluate(image):\n",
        "    attention_plot = np.zeros((max_length, attention_features_shape))\n",
        "\n",
        "    hidden = decoder.reset_state(batch_size=1)\n",
        "\n",
        "    temp_input = tf.expand_dims(load_image(image)[0], 0)\n",
        "    img_tensor_val = image_features_extract_model(temp_input)\n",
        "    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n",
        "\n",
        "    features = encoder(img_tensor_val)\n",
        "\n",
        "    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n",
        "    result = []\n",
        "\n",
        "    for i in range(max_length):\n",
        "        predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n",
        "\n",
        "        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
        "\n",
        "        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()\n",
        "        result.append(tokenizer.index_word[predicted_id])\n",
        "\n",
        "        if tokenizer.index_word[predicted_id] == '<end>':\n",
        "            return result, attention_plot\n",
        "\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "    attention_plot = attention_plot[:len(result), :]\n",
        "    return result, attention_plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fD_y7PD6RPGt"
      },
      "outputs": [],
      "source": [
        "def plot_attention(image, result, attention_plot):\n",
        "    temp_image = np.array(Image.open(image))\n",
        "\n",
        "    fig = plt.figure(figsize=(10, 10))\n",
        "\n",
        "    len_result = len(result)\n",
        "    for l in range(len_result):\n",
        "        temp_att = np.resize(attention_plot[l], (8, 8))\n",
        "        ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n",
        "        ax.set_title(result[l])\n",
        "        img = ax.imshow(temp_image)\n",
        "        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7x8RiPHe_4qI"
      },
      "outputs": [],
      "source": [
        "# 検証用セットのキャプション\n",
        "rid = np.random.randint(0, len(img_name_val))\n",
        "image = img_name_val[rid]\n",
        "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n",
        "result, attention_plot = evaluate(image)\n",
        "\n",
        "print ('Real Caption:', real_caption)\n",
        "print ('Prediction Caption:', ' '.join(result))\n",
        "plot_attention(image, result, attention_plot)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rprk3HEvZuxb"
      },
      "source": [
        "## あなた独自の画像でためそう\n",
        "\n",
        "お楽しみのために、訓練したばかりのモデルであなたの独自の画像を使うためのメソッドを下記に示します。比較的少量のデータで訓練していること、そして、あなたの画像は訓練データとは異なるであろうことを、心に留めておいてください（変な結果が出てくることを覚悟しておいてください）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Psd1quzaAWg"
      },
      "outputs": [],
      "source": [
        "image_url = 'https://tensorflow.org/images/surf.jpg'\n",
        "image_extension = image_url[-4:]\n",
        "image_path = tf.keras.utils.get_file('image'+image_extension,\n",
        "                                     origin=image_url)\n",
        "\n",
        "result, attention_plot = evaluate(image_path)\n",
        "print ('Prediction Caption:', ' '.join(result))\n",
        "plot_attention(image_path, result, attention_plot)\n",
        "# 画像を開く\n",
        "Image.open(image_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VJZXyJco6uLO"
      },
      "source": [
        "# 次のステップ\n",
        "\n",
        "おめでとうございます！アテンション付きの画像キャプショニングモデルの訓練が終わりました。つぎは、この例 [Neural Machine Translation with Attention](../sequences/nmt_with_attention.ipynb) を御覧ください。これはおなじアーキテクチャをつかってスペイン語と英語の間の翻訳を行います。このノートブックのコードをつかって、別のデータセットで訓練を行うこともできます。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "image_captioning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
